# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random,os

os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
from PIL import ImageFilter
import os
import torch
from torch.utils.data import Dataset
from PIL import Image

import torchvision.transforms.v2 as transforms
import torchvision.datasets as datasets
import tonic
from tonic import DiskCachedDataset ,MemoryCachedDataset
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from scipy.ndimage import gaussian_filter
from decord import VideoReader
from decord import cpu, gpu

from torchcodec.decoders import VideoDecoder
import lmdb
import pickle
import time

# Define helper functions for NCALTECH101 dataset transformations
def to_tensor(x):
    """Return a new float32 tensor built from ``x`` (``torch.tensor`` always copies)."""
    target_dtype = torch.float
    converted = torch.tensor(x, dtype=target_dtype)
    return converted
class NormalizePixelValues:
    """Divide a sample by 255 so that 8-bit intensities land in [0, 1]."""

    _MAX_PIXEL = 255.0

    def __call__(self, sample):
        return sample / self._MAX_PIXEL
class ToTensor:
    """Wrap a NumPy array as a tensor (sharing memory) and scale it by 1/255."""

    def __call__(self, sample):
        tensor = torch.from_numpy(sample)
        return tensor / 255.0
    
 
class Kinetics400VideoDataset(Dataset):
    """Kinetics-400 style dataset that decodes clips directly from video files.

    ``root`` must contain one sub-directory per class; every file inside it
    with a recognised video extension becomes one sample. Frames are decoded
    with torchcodec's ``VideoDecoder``.
    """

    def __init__(
        self, root, num_frames=16, interval=1, mode='train',
        num_clips=1, transform=None,
        channel_first=False
    ):
        """
        Args:
            root: Root directory containing one folder per class.
            num_frames: Number of frames sampled per clip.
            interval: Step between consecutive sampled frame indices.
            mode: 'train', or any other value (e.g. 'test'/'val') for evaluation.
            num_clips: Number of clips sampled per video in evaluation mode;
                always 1 in 'train' mode.
            transform: Optional callable applied to each decoded clip.
            channel_first: If True, swap the time and channel axes of the
                returned clip(s) (see ``_maybe_channel_first``).
        """
        # Class indices follow the sorted order of the class folder names.
        class_names = sorted(
            d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
        )
        self.class2idx = {name: idx for idx, name in enumerate(class_names)}

        # Collect (video_path, label) pairs for every video file.
        self.samples = []
        for class_name in class_names:
            class_dir = os.path.join(root, class_name)
            label = self.class2idx[class_name]
            for video_file in os.listdir(class_dir):
                if video_file.endswith(('.mp4', '.avi', '.mkv', '.mov', '.webm')):
                    self.samples.append((os.path.join(class_dir, video_file), label))

        self.num_frames = num_frames
        self.interval = interval
        self.mode = mode
        self.num_clips = num_clips if mode != 'train' else 1  # training samples exactly one clip
        self.transform = transform
        self.classes = list(self.class2idx.keys())
        self.channel_first = channel_first
        print(f"Found {len(self.samples)} videos in {len(self.classes)} classes")

    def __len__(self):
        return len(self.samples)

    def sample_train_indices(self, total_frames, num_frames, interval):
        """Randomly choose the frame indices of one training clip.

        The start is drawn uniformly among the positions where the whole clip
        fits inside the video. Videos shorter than one clip repeat their last
        frame.
        """
        # A clip spans (num_frames - 1) * interval + 1 frames, so the last start
        # keeping every index <= total_frames - 1 is the value below.
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        start = torch.randint(0, max_start + 1, (1,)).item() if max_start > 0 else 0
        return [min(start + i * interval, total_frames - 1) for i in range(num_frames)]

    def sample_multi_clip_indices(self, total_frames, num_frames, interval, num_clips=10):
        """Choose frame indices for ``num_clips`` deterministic evaluation clips.

        A single clip is centred in the video. Otherwise the starts are spread
        uniformly from the first to the last valid start position.
        """
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        if num_clips == 1:
            starts = [max_start // 2]
        else:
            starts = [int(max_start * float(i) / (num_clips - 1)) for i in range(num_clips)]
        return [
            [min(start + i * interval, total_frames - 1) for i in range(num_frames)]
            for start in starts
        ]

    def read_video_decord(self, video_path, indices, decoder=None):
        """Decode the frames at ``indices`` and return them as one tensor.

        Despite its name (kept for backward compatibility), this uses
        torchcodec's ``VideoDecoder`` rather than decord.

        Args:
            video_path: Path of the video, opened only if ``decoder`` is None.
            indices: Frame indices to decode.
            decoder: Optional already-open ``VideoDecoder`` to reuse.

        Returns:
            The ``.data`` tensor of the decoded frame batch (presumably uint8
            [T, C, H, W] — confirm against the torchcodec version in use).
        """
        if decoder is None:
            decoder = VideoDecoder(video_path)
        return decoder.get_frames_at(indices).data

    def __getitem__(self, idx):
        """Return ``(clips, label)`` for video ``idx``.

        If decoding or transforming fails, the sample is replaced by an
        all-zero 224x224 clip, transformed like a real one, and the video's
        real label is kept.
        """
        video_path, label = self.samples[idx]
        try:
            # Open the file once and reuse the decoder for every clip.
            decoder = VideoDecoder(video_path)
            total_frames = self._count_frames(decoder)
            if total_frames <= 0:
                raise ValueError(f"Video {video_path} has no frames")

            if self.mode == 'train':
                indices_list = [self.sample_train_indices(total_frames, self.num_frames, self.interval)]
            else:
                indices_list = self.sample_multi_clip_indices(
                    total_frames, self.num_frames, self.interval, self.num_clips
                )

            clips = [
                self._apply_transform(self.read_video_decord(video_path, indices, decoder=decoder))
                for indices in indices_list
            ]
            # One clip -> [T, C, H, W]; several -> [num_clips, T, C, H, W].
            clips = clips[0] if len(clips) == 1 else torch.stack(clips, dim=0)
            return self._maybe_channel_first(clips), label

        except Exception:
            # Deliberate best-effort fallback so one corrupt file does not stop
            # an epoch; errors are intentionally not logged here.
            return self._zero_clip(), label

    @staticmethod
    def _count_frames(decoder):
        """Return the frame count, preferring stream metadata over ``len(decoder)``."""
        try:
            total_frames = decoder.metadata.num_frames
        except Exception:
            total_frames = None
        # Fall back when the metadata lookup fails or yields no count.
        if total_frames is None:
            total_frames = len(decoder)
        return total_frames

    def _apply_transform(self, frames):
        """Apply ``self.transform`` if set, printing diagnostics before re-raising."""
        if not self.transform:
            return frames
        try:
            return self.transform(frames)
        except Exception as e:
            print(f"Transform error: {e}")
            print(f"Frames shape: {frames.shape}, dtype: {frames.dtype}")
            raise

    def _maybe_channel_first(self, clips):
        """Swap the time and channel axes when ``channel_first`` is set."""
        if not self.channel_first:
            return clips
        # [T, C, H, W] -> [C, T, H, W]; [N, T, C, H, W] -> [N, C, T, H, W]
        if clips.dim() == 4:
            return clips.permute(1, 0, 2, 3)
        return clips.permute(0, 2, 1, 3, 4)

    def _zero_clip(self):
        """Build the all-zero fallback sample, transformed like a real one."""
        if self.mode == 'train' or self.num_clips == 1:
            zero_frames = torch.zeros(self.num_frames, 3, 224, 224)
        else:
            zero_frames = torch.zeros(self.num_clips, self.num_frames, 3, 224, 224)
        return self._maybe_channel_first(self._apply_transform(zero_frames))

class Kinetics400LMDBDataset(Dataset):
    """Kinetics-400 dataset stored as one LMDB record per video.

    Record ``f'{idx:08d}'`` is a pickled dict with ``'frames'`` (per-frame
    encoded image bytes), ``'label'`` and ``'num_frames'``. The
    ``__metadata__`` record holds dataset-level information.

    NOTE(review): records are deserialised with ``pickle``; only open LMDB
    files from a trusted source.
    """

    def __init__(
        self, lmdb_path, num_frames=16, interval=1, mode='train',
        num_clips=1, transform=None, channel_first=False
    ):
        """
        Args:
            lmdb_path: Path of the LMDB database.
            num_frames: Number of frames sampled per clip.
            interval: Step between consecutive sampled frame indices.
            mode: 'train', or any other value (e.g. 'test'/'val') for evaluation.
            num_clips: Number of clips per video in evaluation mode; always 1
                in 'train' mode.
            transform: Optional callable applied to each decoded clip.
            channel_first: If True, swap the time and channel axes of the output.

        Raises:
            KeyError: if the database has no ``__metadata__`` record.
        """
        self.lmdb_path = lmdb_path
        self.num_frames = num_frames
        self.interval = interval
        self.mode = mode
        self.num_clips = num_clips if mode != 'train' else 1  # training samples exactly one clip
        self.transform = transform
        self.channel_first = channel_first
        self.env = None

        # Read the metadata with a short-lived handle that is always closed.
        # The handle used for samples is opened lazily by ``_get_env``
        # (presumably so each DataLoader worker opens its own).
        env = lmdb.open(str(lmdb_path), readonly=True, lock=False)
        try:
            with env.begin() as txn:
                raw_metadata = txn.get(b'__metadata__')
        finally:
            env.close()
        if raw_metadata is None:
            raise KeyError(f"'__metadata__' record missing from LMDB at {lmdb_path}")
        metadata = pickle.loads(raw_metadata)

        self.num_videos = metadata['successful_videos']
        self.classes = metadata['classes']
        self.class2idx = metadata['class_to_idx']
        self.idx2class = metadata['idx_to_class']

        print(f"Loaded LMDB dataset with {self.num_videos} videos in {len(self.classes)} classes")

    def __len__(self):
        return self.num_videos

    def _get_env(self):
        """Open the LMDB environment on first use in the current process."""
        if self.env is None:
            self.env = lmdb.open(str(self.lmdb_path), readonly=True, lock=False,
                                 readahead=True, meminit=False, max_readers=100000)
        return self.env

    def sample_train_indices(self, total_frames, num_frames, interval):
        """Randomly choose the frame indices of one training clip.

        The start is drawn uniformly among the positions where the whole clip
        fits inside the video. Short videos repeat their last frame.
        """
        # A clip spans (num_frames - 1) * interval + 1 frames.
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        start = torch.randint(0, max_start + 1, (1,)).item() if max_start > 0 else 0
        return [min(start + i * interval, total_frames - 1) for i in range(num_frames)]

    def sample_multi_clip_indices(self, total_frames, num_frames, interval, num_clips=10):
        """Choose frame indices for ``num_clips`` deterministic evaluation clips.

        A single clip is centred in the video. Otherwise the starts are spread
        uniformly over the valid start range.
        """
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        if num_clips == 1:
            starts = [max_start // 2]
        else:
            starts = [int(max_start * float(i) / (num_clips - 1)) for i in range(num_clips)]
        return [
            [min(start + i * interval, total_frames - 1) for i in range(num_frames)]
            for start in starts
        ]

    def decode_frames_from_lmdb(self, frames_bytes, indices):
        """Decode the frames at ``indices`` into a uint8 [T, C, H, W] RGB tensor.

        An index past the end of ``frames_bytes`` repeats the previous frame,
        or gives a 224x224 black frame if none has been decoded yet.

        Raises:
            ValueError: if OpenCV cannot decode a frame's bytes.
        """
        frames = []
        for idx in indices:
            if idx < len(frames_bytes):
                encoded = np.frombuffer(frames_bytes[idx], dtype=np.uint8)
                frame = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
                if frame is None:
                    # cv2.imdecode reports undecodable data by returning None.
                    raise ValueError(f"Failed to decode frame {idx}")
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # OpenCV decodes to BGR
            elif frames:
                frames.append(frames[-1])
            else:
                frames.append(np.zeros((224, 224, 3), dtype=np.uint8))

        # [T, H, W, C] -> [T, C, H, W]
        return torch.from_numpy(np.stack(frames, axis=0)).permute(0, 3, 1, 2)

    def __getitem__(self, idx):
        """Return ``(clips, label)`` for video ``idx``.

        Any error is printed, and the sample is replaced by an all-zero
        224x224 clip (transformed like a real one) with label 0.
        """
        try:
            with self._get_env().begin() as txn:
                data = txn.get(f'{idx:08d}'.encode('ascii'))
            if data is None:
                raise ValueError(f"Sample {idx} not found in LMDB")
            sample = pickle.loads(data)

            frames_bytes = sample['frames']
            label = sample['label']
            total_frames = sample['num_frames']

            if self.mode == 'train':
                indices_list = [self.sample_train_indices(total_frames, self.num_frames, self.interval)]
            else:
                indices_list = self.sample_multi_clip_indices(
                    total_frames, self.num_frames, self.interval, self.num_clips
                )

            clips = [
                self._apply_transform(self.decode_frames_from_lmdb(frames_bytes, indices))
                for indices in indices_list
            ]
            # One clip -> [T, C, H, W]; several -> [num_clips, T, C, H, W].
            clips = clips[0] if len(clips) == 1 else torch.stack(clips, dim=0)
            return self._maybe_channel_first(clips), label

        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            # NOTE(review): the real label may be unreadable here, so the
            # fallback is labelled 0 (class 0).
            return self._zero_clip(), 0

    def _apply_transform(self, frames):
        """Apply ``self.transform`` if set, printing diagnostics before re-raising."""
        if not self.transform:
            return frames
        try:
            return self.transform(frames)
        except Exception as e:
            print(f"Transform error: {e}")
            print(f"Frames shape: {frames.shape}, dtype: {frames.dtype}")
            raise

    def _maybe_channel_first(self, clips):
        """Swap the time and channel axes when ``channel_first`` is set."""
        if not self.channel_first:
            return clips
        # [T, C, H, W] -> [C, T, H, W]; [N, T, C, H, W] -> [N, C, T, H, W]
        if clips.dim() == 4:
            return clips.permute(1, 0, 2, 3)
        return clips.permute(0, 2, 1, 3, 4)

    def _zero_clip(self):
        """Build the all-zero fallback sample, transformed like a real one."""
        if self.mode == 'train' or self.num_clips == 1:
            zero_frames = torch.zeros(self.num_frames, 3, 224, 224)
        else:
            zero_frames = torch.zeros(self.num_clips, self.num_frames, 3, 224, 224)
        return self._maybe_channel_first(self._apply_transform(zero_frames))

    def __del__(self):
        """Close the LMDB environment if this process opened one."""
        if hasattr(self, 'env') and self.env is not None:
            self.env.close()

class Kinetics400FrameLMDBDataset(Dataset):
    """Kinetics-400 dataset stored in LMDB with one record per frame.

    Keys used:
        ``__metadata__``: dataset-level metadata;
        ``meta_{video:08d}``: per-video metadata, with at least ``'num_frames'``
        and ``'label'``;
        ``{video:08d}_{frame:04d}``: a pickled dict whose ``'frame_data'`` is
        the encoded image.

    NOTE(review): records are deserialised with ``pickle``; only open LMDB
    files from a trusted source.
    """

    def __init__(
        self, lmdb_path, num_frames=16, interval=1, mode='train',
        num_clips=1, transform=None, channel_first=False
    ):
        """
        Args:
            lmdb_path: Path of the frame-level LMDB database.
            num_frames: Number of frames sampled per clip.
            interval: Step between consecutive sampled frame indices.
            mode: 'train', or any other value (e.g. 'test'/'val') for evaluation.
            num_clips: Number of clips per video in evaluation mode; always 1
                in 'train' mode.
            transform: Optional callable applied to each decoded clip.
            channel_first: If True, swap the time and channel axes of the output.

        Raises:
            KeyError: if the database has no ``__metadata__`` record.
        """
        self.lmdb_path = lmdb_path
        self.num_frames = num_frames
        self.interval = interval
        self.mode = mode
        self.num_clips = num_clips if mode != 'train' else 1  # training samples exactly one clip
        self.transform = transform
        self.channel_first = channel_first
        self.env = None

        # Read dataset metadata and preload every per-video metadata record
        # with a short-lived handle that is always closed. Frames are read
        # later through the lazily opened handle from ``_get_env``.
        env = lmdb.open(str(lmdb_path), readonly=True, lock=False)
        try:
            with env.begin() as txn:
                raw_metadata = txn.get(b'__metadata__')
                if raw_metadata is None:
                    raise KeyError(f"'__metadata__' record missing from LMDB at {lmdb_path}")
                metadata = pickle.loads(raw_metadata)

                self.num_videos = metadata['successful_videos']
                self.classes = metadata['classes']
                self.class2idx = metadata['class_to_idx']
                self.idx2class = metadata['idx_to_class']

                # Videos without a metadata record are simply absent here and
                # fall back to a zero clip in __getitem__.
                self.video_metas = {}
                print("Loading video metadata...")
                for video_idx in range(self.num_videos):
                    meta_data = txn.get(f'meta_{video_idx:08d}'.encode('ascii'))
                    if meta_data is not None:
                        self.video_metas[video_idx] = pickle.loads(meta_data)
        finally:
            env.close()

        print(f"Loaded frame-based LMDB dataset with {self.num_videos} videos in {len(self.classes)} classes")
        print(f"Total frames: {metadata.get('total_frames', 'Unknown')}")

    def __len__(self):
        return self.num_videos

    def _get_env(self):
        """Open the LMDB environment (and configure OpenCV) on first use in this process."""
        if self.env is None:
            # Process-wide OpenCV settings, applied once together with the lazy
            # open rather than on every frame load. Presumably meant to keep
            # OpenCV from using OpenCL or its own thread pool inside DataLoader
            # workers.
            cv2.ocl.setUseOpenCL(False)
            cv2.setNumThreads(0)
            self.env = lmdb.open(str(self.lmdb_path), readonly=True, lock=False,
                                 readahead=True, meminit=False, max_readers=100000)
        return self.env

    def sample_train_indices(self, total_frames, num_frames, interval):
        """Randomly choose the frame indices of one training clip.

        The start is drawn uniformly among the positions where the whole clip
        fits inside the video. Short videos repeat their last frame.
        """
        # A clip spans (num_frames - 1) * interval + 1 frames.
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        start = torch.randint(0, max_start + 1, (1,)).item() if max_start > 0 else 0
        return [min(start + i * interval, total_frames - 1) for i in range(num_frames)]

    def sample_multi_clip_indices(self, total_frames, num_frames, interval, num_clips=10):
        """Choose frame indices for ``num_clips`` deterministic evaluation clips.

        A single clip is centred in the video. Otherwise the starts are spread
        uniformly over the valid start range.
        """
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        if num_clips == 1:
            starts = [max_start // 2]
        else:
            starts = [int(max_start * float(i) / (num_clips - 1)) for i in range(num_clips)]
        return [
            [min(start + i * interval, total_frames - 1) for i in range(num_frames)]
            for start in starts
        ]

    def load_frames_from_lmdb(self, video_idx, frame_indices):
        """Load and decode frames ``frame_indices`` of video ``video_idx``.

        A missing frame record repeats the previous frame, or gives a 224x224
        black frame if none has been loaded yet.

        Returns:
            uint8 RGB tensor of shape [T, C, H, W].

        Raises:
            ValueError: if a stored frame cannot be decoded by OpenCV.
        """
        frames = []
        with self._get_env().begin() as txn:
            for frame_idx in frame_indices:
                # Composite key: video_idx (8 digits) _ frame_idx (4 digits).
                frame_data = txn.get(f'{video_idx:08d}_{frame_idx:04d}'.encode('ascii'))
                if frame_data is None:
                    frames.append(frames[-1] if frames else np.zeros((224, 224, 3), dtype=np.uint8))
                    continue

                frame_bytes = pickle.loads(frame_data)['frame_data']
                frame = cv2.imdecode(np.frombuffer(frame_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
                if frame is None:
                    # cv2.imdecode reports undecodable data by returning None.
                    raise ValueError(f"Failed to decode frame {frame_idx} of video {video_idx}")
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # OpenCV decodes to BGR

        # [T, H, W, C] -> [T, C, H, W]
        return torch.from_numpy(np.stack(frames, axis=0)).permute(0, 3, 1, 2)

    def __getitem__(self, idx):
        """Return ``(clips, label)`` for video ``idx``.

        Any error is printed, and the sample is replaced by an all-zero
        224x224 clip (transformed like a real one) with label 0.
        """
        try:
            if idx not in self.video_metas:
                raise ValueError(f"Video {idx} metadata not found")
            video_meta = self.video_metas[idx]
            total_frames = video_meta['num_frames']
            label = video_meta['label']

            if self.mode == 'train':
                indices_list = [self.sample_train_indices(total_frames, self.num_frames, self.interval)]
            else:
                indices_list = self.sample_multi_clip_indices(
                    total_frames, self.num_frames, self.interval, self.num_clips
                )

            clips = [
                self._apply_transform(self.load_frames_from_lmdb(idx, indices))
                for indices in indices_list
            ]
            # One clip -> [T, C, H, W]; several -> [num_clips, T, C, H, W].
            clips = clips[0] if len(clips) == 1 else torch.stack(clips, dim=0)
            return self._maybe_channel_first(clips), label

        except Exception as e:
            print(f"Error processing video {idx}: {e}")
            # NOTE(review): the fallback is labelled 0 (class 0), not the real label.
            return self._zero_clip(), 0

    def _apply_transform(self, frames):
        """Apply ``self.transform`` if set, printing diagnostics before re-raising."""
        if not self.transform:
            return frames
        try:
            return self.transform(frames)
        except Exception as e:
            print(f"Transform error: {e}")
            print(f"Frames shape: {frames.shape}, dtype: {frames.dtype}")
            raise

    def _maybe_channel_first(self, clips):
        """Swap the time and channel axes when ``channel_first`` is set."""
        if not self.channel_first:
            return clips
        # [T, C, H, W] -> [C, T, H, W]; [N, T, C, H, W] -> [N, C, T, H, W]
        if clips.dim() == 4:
            return clips.permute(1, 0, 2, 3)
        return clips.permute(0, 2, 1, 3, 4)

    def _zero_clip(self):
        """Build the all-zero fallback sample, transformed like a real one."""
        if self.mode == 'train' or self.num_clips == 1:
            zero_frames = torch.zeros(self.num_frames, 3, 224, 224)
        else:
            zero_frames = torch.zeros(self.num_clips, self.num_frames, 3, 224, 224)
        return self._maybe_channel_first(self._apply_transform(zero_frames))

    def get_video_info(self, idx):
        """Return the preloaded metadata dict of video ``idx``, or None."""
        return self.video_metas.get(idx, None)

    def get_class_names(self):
        """Return the list of class names read from the metadata."""
        return self.classes

    def __del__(self):
        """Close the LMDB environment if this process opened one."""
        if hasattr(self, 'env') and self.env is not None:
            self.env.close()


class UCF101FramesDatasetV2(Dataset):
    """UCF-101 dataset that reads clips from pre-extracted frame images.

    Frames of a video are read from ``frames_root/<Class>/<video name>/``.
    Only ``.jpg``/``.png`` files are used, and their sorted file names are
    assumed to be in temporal order.
    """

    def __init__(
        self, frames_root, split_txt, class_ind_txt,
        num_frames=16, interval=1, mode='train', num_clips=1, transform=None,
        channel_first=False, multi_train_sample=False
    ):
        """
        Args:
            frames_root: Root directory of the per-video frame folders.
            split_txt: Split list. The first field of each line is a
                ``<Class>/<video file>`` path.
            class_ind_txt: Class index file with ``<1-based index> <Class>`` lines.
            num_frames: Frames per clip.
            interval: Step between sampled frame indices.
            mode: 'train', or an evaluation mode such as 'test'.
            num_clips: Clips per video. Unlike the other datasets this is not
                forced to 1 in training; it is used there when
                ``multi_train_sample`` is set.
            transform: Optional callable applied to each [T, C, H, W] clip. It
                may return a tensor or a sequence of views.
            channel_first: If True, swap the time and channel axes of the output.
            multi_train_sample: If True and ``mode == 'train'``, sample
                ``num_clips`` random clips per video instead of one.
        """
        self.class2idx = {}
        with open(class_ind_txt) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                idx, classname = line.strip().split()
                self.class2idx[classname] = int(idx) - 1  # file indices are 1-based

        self.samples = []
        with open(split_txt) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                path = line.strip().split()[0]
                class_name = path.split('/')[0]
                video_name = os.path.splitext(os.path.basename(path))[0]
                frame_dir = os.path.join(frames_root, class_name, video_name)
                self.samples.append((frame_dir, self.class2idx[class_name]))

        self.num_frames = num_frames
        self.interval = interval
        self.mode = mode
        self.num_clips = num_clips
        self.transform = transform
        self.channel_first = channel_first
        self.classes = list(self.class2idx.keys())
        self.multi_train_sample = multi_train_sample

    def __len__(self):
        return len(self.samples)

    def sample_train_indices(self, total_frames, num_frames, interval):
        """Randomly choose the frame indices of one training clip.

        The start is drawn uniformly among the positions where the whole clip
        fits inside the video. Short videos repeat their last frame.
        """
        # A clip spans (num_frames - 1) * interval + 1 frames.
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        start = torch.randint(0, max_start + 1, (1,)).item() if max_start > 0 else 0
        return [min(start + i * interval, total_frames - 1) for i in range(num_frames)]

    def sample_multi_clip_indices(self, total_frames, num_frames, interval, num_clips=10):
        """Choose frame indices for ``num_clips`` clips.

        With ``multi_train_sample`` in 'train' mode, the starts are random
        (sorted). Otherwise a single clip is centred in the video, and several
        clips are spread uniformly over the valid start range.
        """
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        # The random training branch must come first; otherwise num_clips == 1
        # would silently turn multi-sample training into a fixed centre clip.
        if self.multi_train_sample and self.mode == 'train':
            starts = sorted(
                torch.randint(0, max_start + 1, (1,)).item() if max_start > 0 else 0
                for _ in range(num_clips)
            )
        elif num_clips == 1:
            starts = [max_start // 2]
        else:
            starts = [int(max_start * float(i) / (num_clips - 1)) for i in range(num_clips)]
        return [
            [min(start + i * interval, total_frames - 1) for i in range(num_frames)]
            for start in starts
        ]

    def read_img_cv2(self, path):
        """Read an image file as a uint8 RGB tensor of shape [C, H, W].

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            raise ValueError(f"Cannot read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return torch.from_numpy(img).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return ``(clips, label)`` for video ``idx``.

        Raises:
            ValueError: if the video's frame folder contains no frames.
        """
        frame_dir, label = self.samples[idx]
        frame_files = sorted(f for f in os.listdir(frame_dir) if f.endswith(('.jpg', '.png')))
        if not frame_files:
            raise ValueError(f"No frame files found in {frame_dir}")
        total_frames = len(frame_files)

        if self.mode == 'train' and not self.multi_train_sample:
            indices_list = [self.sample_train_indices(total_frames, self.num_frames, self.interval)]
        else:
            indices_list = self.sample_multi_clip_indices(total_frames, self.num_frames, self.interval, self.num_clips)

        clips = []
        for indices in indices_list:
            frames = torch.stack(
                [self.read_img_cv2(os.path.join(frame_dir, frame_files[i])) for i in indices],
                dim=0,
            )  # [T, C, H, W]
            clips.append(self._apply_transform(frames))
        return self._format_clips(clips), label

    def _apply_transform(self, frames):
        """Apply ``self.transform`` if set, printing diagnostics before re-raising."""
        if not self.transform:
            return frames
        try:
            return self.transform(frames)  # transforms.v2 applies the same params to the whole clip
        except Exception as e:
            print(e)
            print(frames.shape, frames.dtype, frames.max(), type(frames))
            raise

    def _format_clips(self, clips):
        """Merge per-clip outputs and optionally move channels before time.

        Tensor clips are stacked to [num_clips, T, C, H, W]. When the transform
        returns several views per clip, the result is a list with one stacked
        tensor per view. A single clip is returned unstacked.
        """
        if len(clips) == 1:
            clips = clips[0]  # [T, C, H, W] (or the transform's views)
        elif isinstance(clips[0], torch.Tensor):
            clips = torch.stack(clips, dim=0)  # [num_clips, T, C, H, W]
        else:
            # Regroup by view: view i of every clip stacked together.
            clips = [torch.stack([clip[i] for clip in clips], dim=0) for i in range(len(clips[0]))]

        if self.channel_first:
            # isinstance (not type ==) so Tensor subclasses such as
            # torchvision tv_tensors are not treated as a list of views.
            if isinstance(clips, torch.Tensor):
                clips = self._swap_time_channel(clips)
            else:
                clips = [self._swap_time_channel(clip) for clip in clips]
        return clips

    @staticmethod
    def _swap_time_channel(clip):
        """[T, C, H, W] -> [C, T, H, W]; [N, T, C, H, W] -> [N, C, T, H, W]."""
        if clip.dim() == 4:
            return clip.permute(1, 0, 2, 3)
        return clip.permute(0, 2, 1, 3, 4)


class HMDB51FramesDatasetV2(Dataset):
    """HMDB-51 dataset that reads clips from pre-extracted frame images.

    Samples come from the official ``<class>_test_split<N>.txt`` files. Frames
    of a video are read from ``frames_root/<class>/<video name>/``; only
    ``.jpg``/``.png`` files are used, and their sorted names are assumed to be
    in temporal order.
    """

    def __init__(
        self, frames_root, split_dir, split_num=1,
        num_frames=16, interval=1, mode='train', num_clips=1, transform=None
    ):
        """
        Args:
            frames_root: Root directory of the frame images.
            split_dir: Directory containing all split files
                (e.g. testTrainMulti_7030_splits/).
            split_num: Which official split to use (1, 2 or 3).
            num_frames: Frames per clip.
            interval: Step between sampled frame indices.
            mode: 'train' selects the training videos. Any other value
                (e.g. 'test' or 'val') selects the test videos.
            num_clips: Clips per video in evaluation mode; always 1 in 'train'.
            transform: Optional callable applied to each [T, C, H, W] clip.
        """
        # The 51 HMDB-51 classes, in label order.
        self.class_names = [
            'brush_hair', 'cartwheel', 'catch', 'chew', 'clap', 'climb', 'climb_stairs',
            'dive', 'draw_sword', 'dribble', 'drink', 'eat', 'fall_floor', 'fencing',
            'flic_flac', 'golf', 'handstand', 'hit', 'hug', 'jump', 'kick', 'kick_ball',
            'kiss', 'laugh', 'pick', 'pour', 'pullup', 'punch', 'push', 'pushup',
            'ride_bike', 'ride_horse', 'run', 'shake_hands', 'shoot_ball', 'shoot_bow',
            'shoot_gun', 'sit', 'situp', 'smile', 'smoke', 'somersault', 'stand',
            'swing_baseball', 'sword', 'sword_exercise', 'talk', 'throw', 'turn',
            'walk', 'wave'
        ]
        self.class2idx = {name: idx for idx, name in enumerate(self.class_names)}

        # Split files tag each video 1 = train, 2 = test, 0 = unused. Every
        # non-'train' mode maps to the test videos, matching the other datasets
        # here that treat any non-'train' mode as evaluation.
        wanted_split = 1 if mode == 'train' else 2

        self.samples = []
        for class_name in self.class_names:
            split_file = os.path.join(split_dir, f"{class_name}_test_split{split_num}.txt")
            if not os.path.exists(split_file):
                print(f"Warning: Split file not found: {split_file}")
                continue

            label = self.class2idx[class_name]
            with open(split_file, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) != 2 or int(parts[1]) != wanted_split:
                        continue
                    video_name_no_ext = os.path.splitext(parts[0])[0]
                    frame_dir = os.path.join(frames_root, class_name, video_name_no_ext)
                    # Videos whose frame folder does not exist are skipped silently.
                    if os.path.exists(frame_dir):
                        self.samples.append((frame_dir, label))

        print(f"HMDB51 {mode} split {split_num}: {len(self.samples)} videos loaded")

        self.num_frames = num_frames
        self.interval = interval
        self.mode = mode
        self.num_clips = num_clips if mode != 'train' else 1
        self.transform = transform
        self.classes = list(self.class2idx.keys())

    def __len__(self):
        return len(self.samples)

    def sample_train_indices(self, total_frames, num_frames, interval):
        """Randomly choose the frame indices of one training clip.

        The start is drawn uniformly among the positions where the whole clip
        fits inside the video. Short videos repeat their last frame.
        """
        # A clip spans (num_frames - 1) * interval + 1 frames.
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        start = torch.randint(0, max_start + 1, (1,)).item() if max_start > 0 else 0
        return [min(start + i * interval, total_frames - 1) for i in range(num_frames)]

    def sample_multi_clip_indices(self, total_frames, num_frames, interval, num_clips=10):
        """Choose frame indices for ``num_clips`` deterministic evaluation clips.

        A single clip is centred in the video. Otherwise the starts are spread
        uniformly over the valid start range.
        """
        max_start = max(0, total_frames - 1 - (num_frames - 1) * interval)
        if num_clips == 1:
            starts = [max_start // 2]
        else:
            starts = [int(max_start * float(i) / (num_clips - 1)) for i in range(num_clips)]
        return [
            [min(start + i * interval, total_frames - 1) for i in range(num_frames)]
            for start in starts
        ]

    def read_img_cv2(self, path):
        """Read an image file as a uint8 RGB tensor of shape [C, H, W].

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            raise ValueError(f"Cannot read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return torch.from_numpy(img).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return ``(clips, label)``: [T, C, H, W] or [num_clips, T, C, H, W].

        Raises:
            ValueError: if the video's frame folder contains no frames.
        """
        frame_dir, label = self.samples[idx]
        frame_files = sorted(f for f in os.listdir(frame_dir) if f.endswith(('.jpg', '.png')))
        if not frame_files:
            raise ValueError(f"No frame files found in {frame_dir}")
        total_frames = len(frame_files)

        if self.mode == 'train':
            indices_list = [self.sample_train_indices(total_frames, self.num_frames, self.interval)]
        else:
            indices_list = self.sample_multi_clip_indices(total_frames, self.num_frames, self.interval, self.num_clips)

        clips = []
        for indices in indices_list:
            frames = torch.stack(
                [self.read_img_cv2(os.path.join(frame_dir, frame_files[i])) for i in indices],
                dim=0,
            )  # [T, C, H, W]
            if self.transform:
                try:
                    frames = self.transform(frames)
                except Exception as e:
                    print(f"Transform error: {e}")
                    print(f"Frames shape: {frames.shape}, dtype: {frames.dtype}, max: {frames.max()}")
                    raise
            clips.append(frames)

        # One clip -> [T, C, H, W]; several -> [num_clips, T, C, H, W].
        return (clips[0] if len(clips) == 1 else torch.stack(clips, dim=0)), label

class RandomFlipSequence:
    """With probability ``p``, flip the whole tensor along dimension ``dim``.

    One random draw decides the flip for the entire input, so every frame of a
    clip is flipped (or kept) together.
    """

    def __init__(self, p=0.5, dim=3):
        self.p = p
        self.dim = dim

    def __call__(self, x):
        should_flip = random.random() < self.p
        if not should_flip:
            return x
        return torch.flip(x, dims=(self.dim,))

class RandomRollSequence:
    """Circularly shift dims 2 and 3 by random offsets in [-max_shift, max_shift].

    The same (height, width) shift is used for the whole tensor on each call.
    """

    def __init__(self, max_shift=5):
        self.max_shift = max_shift

    def __call__(self, x):
        bound = self.max_shift
        shift_h = random.randint(-bound, bound)
        shift_w = random.randint(-bound, bound)
        offsets = (shift_h, shift_w)
        return torch.roll(x, shifts=offsets, dims=(2, 3))
        
class GaussianBlurEvents:
    """
    Apply Gaussian blur to the two trailing (spatial) dims of an event tensor.

    Supports ``[C, H, W]`` and ``[N, C, H, W]`` inputs; each 2-D slice is
    blurred independently. Tensors of any other rank come back unmodified
    (as a new tensor on the same device). The input tensor is never changed.

    Args:
        sigma: Standard deviation for the Gaussian kernel, either a fixed
            number or a ``(min, max)`` range sampled uniformly on every call.
    """

    def __init__(self, sigma=(0.1, 2.0)):
        # Tuple default: avoids a mutable default shared between instances.
        self.sigma = sigma

    def __call__(self, x):
        if isinstance(self.sigma, (list, tuple)):
            sigma = random.uniform(self.sigma[0], self.sigma[1])
        else:
            sigma = self.sigma

        original_device = x.device
        # ``Tensor.numpy()`` shares memory with a CPU tensor, so copy before the
        # in-place slice writes below; otherwise the caller's tensor would be
        # blurred as well.
        x_np = x.detach().cpu().numpy().copy()

        if x_np.ndim in (3, 4):
            # Iterate over every leading index (C, or N x C) and blur the H x W slice.
            for idx in np.ndindex(x_np.shape[:-2]):
                x_np[idx] = gaussian_filter(x_np[idx], sigma=sigma)

        return torch.from_numpy(x_np).to(original_device)

class ColorJitterEvents:
    """
    Randomly rescale brightness and contrast of an event-frame tensor.

    Args:
        brightness: a non-negative number ``b`` (factor drawn from
            ``[max(0, 1 - b), 1 + b]``) or an increasing ``[min, max]`` pair.
        contrast: same format as ``brightness``; the contrast factor scales the
            deviation from the per-channel spatial mean.
        saturation: validated but never applied; accepted only so the
            signature mirrors torchvision's ColorJitter.
        hue: validated but never applied; same reason as ``saturation``.
    """

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1):
        self.brightness = self._check_input(brightness)
        self.contrast = self._check_input(contrast)
        self.saturation = self._check_input(saturation)
        self.hue = self._check_input(hue)

    def _check_input(self, value):
        """Normalise ``value`` into a ``[low, high]`` sampling interval."""
        if isinstance(value, (list, tuple)) and len(value) == 2:
            if not value[0] < value[1]:
                raise ValueError(f"Interval values should be increasing, got {value}")
            return value
        if isinstance(value, (int, float)):
            if not value >= 0:
                raise ValueError(f"Value should be positive, got {value}")
            return [max(0, 1 - value), 1 + value]
        raise TypeError(f"Value should be a single number or a list/tuple with length 2, got {value}")

    def _get_params(self):
        """Draw ``(brightness_factor, contrast_factor)``, brightness first."""
        low_b, high_b = self.brightness[0], self.brightness[1]
        low_c, high_c = self.contrast[0], self.contrast[1]
        return random.uniform(low_b, high_b), random.uniform(low_c, high_c)

    def __call__(self, x):
        brightness_factor, contrast_factor = self._get_params()
        rank = len(x.shape)

        # Brightness is a plain multiplicative gain.
        scaled = x * brightness_factor
        if contrast_factor == 1 or rank not in (3, 4):
            return scaled

        # Spatial axes: (H, W) of [C, H, W] or [N, C, H, W].
        spatial_dims = (1, 2) if rank == 3 else (2, 3)
        centre = scaled.mean(dim=spatial_dims, keepdim=True)
        return (scaled - centre) * contrast_factor + centre

class RandomGrayscaleEvents:
    """
    With probability ``p``, replace every channel by the channel mean.

    ``[C, H, W]`` inputs average over dim 0 and ``[N, C, H, W]`` inputs over
    dim 1; the channel count is preserved by repeating the mean. Inputs of any
    other rank are returned as-is.

    Args:
        p: Probability of applying the grayscale transformation.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, x):
        if random.random() >= self.p:
            return x

        rank = len(x.shape)
        if rank == 3:
            channels = x.shape[0]
            return x.mean(dim=0, keepdim=True).repeat(channels, 1, 1)
        if rank == 4:
            channels = x.shape[1]
            return x.mean(dim=1, keepdim=True).repeat(1, channels, 1, 1)
        return x

def multi_transforms(x):
    """Gaussian-blur an event tensor with probability 0.5.

    When applied, the blur sigma is sampled by ``GaussianBlurEvents`` from the
    range ``[0.1, 2.0]``; otherwise ``x`` is returned untouched.
    """
    if random.random() >= 0.5:
        return x
    return GaussianBlurEvents(sigma=[0.1, 2.0])(x)
 
class ResizeTensor:
    """Bilinearly resize the spatial dims of a 4-D tensor to ``size x size``.

    Uses ``F.interpolate`` with ``align_corners=True``.
    """

    def __init__(self, size=48):
        self.size = size

    def __call__(self, x):
        target_hw = [self.size] * 2
        return F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)

class TwoCropsTransform:
    """Apply ``base_transform`` twice to one input, returning ``[query, key]``."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Two independent calls, so random augmentations differ between the views.
        return [self.base_transform(x) for _ in range(2)]


class GaussianBlur:
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Args:
        sigma: blur radius, either a ``(min, max)`` range sampled uniformly on
            every call or a single number used as a fixed radius.
    """

    def __init__(self, sigma=(0.1, 2.0)):
        # Tuple default: avoids a mutable default list shared across instances.
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` blurred via its PIL ``filter`` method (``x`` is expected to be a PIL image)."""
        if isinstance(self.sigma, (list, tuple)):
            sigma = random.uniform(self.sigma[0], self.sigma[1])
        else:
            sigma = self.sigma
        return x.filter(ImageFilter.GaussianBlur(radius=sigma))

# Move SubsetTransformDataset outside the function so it can be pickled
class SubsetTransformDataset(torch.utils.data.Dataset):
    """Index-restricted view of ``dataset`` with an optional per-sample transform.

    Position ``idx`` maps to ``dataset[indices[idx]]``, which must yield a
    ``(data, target)`` pair. ``transform`` (if not None) is applied to ``data``
    on every access; ``target`` is passed through unchanged.
    """

    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample, label = self.dataset[self.indices[idx]]
        if self.transform is None:
            return sample, label
        return self.transform(sample), label
 
# Class-folder name -> integer label for N-Caltech101 (101 classes incl. background).
NCALTECH101_CLASSES = {
    'BACKGROUND_Google': 0,
    'Faces_easy': 1,
    'Leopards': 2,
    'Motorbikes': 3,
    'accordion': 4,
    'airplanes': 5,
    'anchor': 6,
    'ant': 7,
    'barrel': 8,
    'bass': 9,
    'beaver': 10,
    'binocular': 11,
    'bonsai': 12,
    'brain': 13,
    'brontosaurus': 14,
    'buddha': 15,
    'butterfly': 16,
    'camera': 17,
    'cannon': 18,
    'car_side': 19,
    'ceiling_fan': 20,
    'cellphone': 21,
    'chair': 22,
    'chandelier': 23,
    'cougar_body': 24,
    'cougar_face': 25,
    'crab': 26,
    'crayfish': 27,
    'crocodile': 28,
    'crocodile_head': 29,
    'cup': 30,
    'dalmatian': 31,
    'dollar_bill': 32,
    'dolphin': 33,
    'dragonfly': 34,
    'electric_guitar': 35,
    'elephant': 36,
    'emu': 37,
    'euphonium': 38,
    'ewer': 39,
    'ferry': 40,
    'flamingo': 41,
    'flamingo_head': 42,
    'garfield': 43,
    'gerenuk': 44,
    'gramophone': 45,
    'grand_piano': 46,
    'hawksbill': 47,
    'headphone': 48,
    'hedgehog': 49,
    'helicopter': 50,
    'ibis': 51,
    'inline_skate': 52,
    'joshua_tree': 53,
    'kangaroo': 54,
    'ketch': 55,
    'lamp': 56,
    'laptop': 57,
    'llama': 58,
    'lobster': 59,
    'lotus': 60,
    'mandolin': 61,
    'mayfly': 62,
    'menorah': 63,
    'metronome': 64,
    'minaret': 65,
    'nautilus': 66,
    'octopus': 67,
    'okapi': 68,
    'pagoda': 69,
    'panda': 70,
    'pigeon': 71,
    'pizza': 72,
    'platypus': 73,
    'pyramid': 74,
    'revolver': 75,
    'rhino': 76,
    'rooster': 77,
    'saxophone': 78,
    'schooner': 79,
    'scissors': 80,
    'scorpion': 81,
    'sea_horse': 82,
    'snoopy': 83,
    'soccer_ball': 84,
    'stapler': 85,
    'starfish': 86,
    'stegosaurus': 87,
    'stop_sign': 88,
    'strawberry': 89,
    'sunflower': 90,
    'tick': 91,
    'trilobite': 92,
    'umbrella': 93,
    'watch': 94,
    'water_lilly': 95,
    'wheelchair': 96,
    'wild_cat': 97,
    'windsor_chair': 98,
    'wrench': 99,
    'yin_yang': 100,
}


def _ncaltech101_label_name(x):
    """Reduce a raw N-Caltech101 target to its class-folder name."""
    if isinstance(x, (bytes, bytearray)):
        return x.decode()
    text = str(x)
    parts = text.split("'")
    # A repr-like value such as "b'accordion'" keeps the name just before the
    # last quote; a plain name without quotes is used as-is.
    return parts[-2] if len(parts) >= 2 else text


def label_transform(x):
    """Map a raw N-Caltech101 target to its integer class index.

    ``bytes`` targets are decoded; any other value is stringified, and when it
    contains single quotes (e.g. ``"b'accordion'"``) the segment just before the
    last quote is taken as the class name.

    Raises:
        KeyError: if the extracted name is not an N-Caltech101 class.
    """
    # Keep publishing the mapping on tonic's dataset class as before; other code
    # may read ``tonic.datasets.NCALTECH101.classes`` (NOTE(review): confirm).
    tonic.datasets.NCALTECH101.classes = NCALTECH101_CLASSES
    return NCALTECH101_CLASSES[_ncaltech101_label_name(x)]
 
def get_imgnet_datasets(args,linearCls=False):
    """Build ImageNet ``ImageFolder`` datasets for MoCo-style pre-training or linear evaluation.

    Args:
        args: namespace with ``dataset_path`` (holding a ``train`` folder and,
            for linear evaluation, a ``val`` folder). When ``linearCls`` is
            False, ``args.aug_plus`` selects the MoCo v2 augmentation recipe
            instead of MoCo v1.
        linearCls: False -> pre-training set whose samples are two augmented
            views; True -> train/val sets for linear-classifier evaluation.

    Returns:
        ``(train_dataset, None)`` for pre-training, otherwise
        ``(train_dataset, test_dataset)``.
    """
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    if linearCls:
        fit_pipeline = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        eval_pipeline = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        train_dataset = datasets.ImageFolder(
            os.path.join(args.dataset_path , "train"), transforms.Compose(fit_pipeline)
        )
        test_dataset = datasets.ImageFolder(
            os.path.join(args.dataset_path , "val"), transforms.Compose(eval_pipeline)
        )
        return train_dataset, test_dataset

    if args.aug_plus:
        # MoCo v2 recipe: colour jitter applied with p=0.8, plus Gaussian blur.
        pretrain_pipeline = [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    else:
        # MoCo v1 recipe, identical to InstDisc https://arxiv.org/abs/1805.01978
        pretrain_pipeline = [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]

    two_views = TwoCropsTransform(transforms.Compose(pretrain_pipeline))
    train_dataset = datasets.ImageFolder(os.path.join(args.dataset_path , "train"), two_views)
    return train_dataset, None
    
def get_cifar10_datasets(args,linearCls=False):
    """Build CIFAR-10 datasets for self-supervised pre-training or linear evaluation.

    Args:
        args: namespace with ``dataset_path``; when ``linearCls`` is False,
            ``args.aug_plus`` picks the stronger (randomly applied) colour
            jitter recipe.
        linearCls: False -> pre-training datasets whose training samples are
            two augmented views; True -> linear-evaluation datasets.

    Returns:
        ``(train_dataset, test_dataset)`` when ``linearCls`` is False, or
        ``(train_dataset, knn_train_dataset, test_dataset)`` when True.
    """
    mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    eval_pipeline = [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    if linearCls:
        fit_pipeline = [
            transforms.RandomResizedCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        root = args.dataset_path
        train_dataset = datasets.CIFAR10(root, train=True, transform=transforms.Compose(fit_pipeline))
        knn_train_dataset = datasets.CIFAR10(root, train=True, transform=transforms.Compose(eval_pipeline))
        test_dataset = datasets.CIFAR10(root, train=False, transform=transforms.Compose(eval_pipeline))
        return train_dataset, knn_train_dataset, test_dataset

    if args.aug_plus:
        pretrain_pipeline = [
            transforms.RandomResizedCrop(32),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    else:
        pretrain_pipeline = [
            transforms.RandomResizedCrop(32),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]

    root = args.dataset_path
    two_views = TwoCropsTransform(transforms.Compose(pretrain_pipeline))
    train_dataset = datasets.CIFAR10(root, train=True, transform=two_views)
    test_dataset = datasets.CIFAR10(root, train=False, transform=transforms.Compose(eval_pipeline))
    return train_dataset, test_dataset

def get_cifar100_datasets(args,linearCls=False):
    """Build CIFAR-100 datasets for self-supervised pre-training or linear evaluation.

    Args:
        args: namespace with ``dataset_path``; when ``linearCls`` is False,
            ``args.aug_plus`` picks the stronger (randomly applied) colour
            jitter recipe.
        linearCls: False -> pre-training datasets whose training samples are
            two augmented views; True -> linear-evaluation datasets.

    Returns:
        ``(train_dataset, test_dataset)`` when ``linearCls`` is False, or
        ``(train_dataset, knn_train_dataset, test_dataset)`` when True.
    """
    # NOTE: the normalisation statistics are the CIFAR-10 ones, as before.
    mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    eval_pipeline = [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    if linearCls:
        fit_pipeline = [
            transforms.RandomResizedCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        root = args.dataset_path
        train_dataset = datasets.CIFAR100(root, train=True, transform=transforms.Compose(fit_pipeline))
        knn_train_dataset = datasets.CIFAR100(root, train=True, transform=transforms.Compose(eval_pipeline))
        test_dataset = datasets.CIFAR100(root, train=False, transform=transforms.Compose(eval_pipeline))
        return train_dataset, knn_train_dataset, test_dataset

    if args.aug_plus:
        pretrain_pipeline = [
            transforms.RandomResizedCrop(32),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    else:
        pretrain_pipeline = [
            transforms.RandomResizedCrop(32),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]

    root = args.dataset_path
    two_views = TwoCropsTransform(transforms.Compose(pretrain_pipeline))
    train_dataset = datasets.CIFAR100(root, train=True, transform=two_views)
    test_dataset = datasets.CIFAR100(root, train=False, transform=transforms.Compose(eval_pipeline))
    return train_dataset, test_dataset
    
    

def get_NCALTECH101_datasets(args, linearCls=False):
    """
    Get NCaltech101 dataset
    http://journal.frontiersin.org/Article/10.3389/fnins.2015.00437/abstract

    Events are binned into ``args.step`` frames with tonic's ``ToFrame``,
    converted to float tensors, resized to ``48 x 48`` and cached on disk at
    ``args.cache_path``. The dataset has no default split, so the first 90% of
    every class goes to training and the rest to testing. This assumes samples
    are stored contiguously per class in the order of ``cls_count`` (TODO
    confirm against tonic's ordering).

    Args:
        args: namespace with ``dataset_path``, ``cache_path`` and ``step``.
            ``is_knn_dataset`` is optional and only consulted when
            ``linearCls`` is True.
        linearCls: False -> pre-training datasets whose training samples are
            two augmented views; True -> linear-evaluation datasets.

    Returns:
        ``(train_dataset, knn_train_dataset, test_dataset)`` when ``linearCls``
        is False, otherwise ``(train_dataset, test_dataset)``. Each dataset's
        ``classes`` attribute is set to the per-class sample-count list.
    """
    sensor_size = tonic.datasets.NCALTECH101.sensor_size
    # Number of samples in each class, in dataset order.
    cls_count = [467,
                 435, 200, 798, 55, 800, 42, 42, 47, 54, 46,
                 33, 128, 98, 43, 85, 91, 50, 43, 123, 47,
                 59, 62, 107, 47, 69, 73, 70, 50, 51, 57,
                 67, 52, 65, 68, 75, 64, 53, 64, 85, 67,
                 67, 45, 34, 34, 51, 99, 100, 42, 54, 88,
                 80, 31, 64, 86, 114, 61, 81, 78, 41, 66,
                 43, 40, 87, 32, 76, 55, 35, 39, 47, 38,
                 45, 53, 34, 57, 82, 59, 49, 40, 63, 39,
                 84, 57, 35, 64, 45, 86, 59, 64, 35, 85,
                 49, 86, 75, 239, 37, 59, 34, 56, 39, 60]
    portion = 0.9
    size = 48

    # Per-class split: first round(portion * count) indices train, remainder test.
    train_sample_index = []
    test_sample_index = []
    idx_begin = 0
    for count in cls_count:
        train_sample = round(portion * count)
        train_sample_index.extend(range(idx_begin, idx_begin + train_sample))
        test_sample_index.extend(range(idx_begin + train_sample, idx_begin + count))
        idx_begin += count

    # Only the deterministic conversion is cached; random augmentations are applied
    # per access by SubsetTransformDataset below.
    toframe_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=args.step)
    ])
    base_dataset = tonic.datasets.NCALTECH101(args.dataset_path, transform=toframe_transform)
    cached_dataset = DiskCachedDataset(
        base_dataset,
        cache_path=args.cache_path,
        transform=transforms.Compose([
            to_tensor,
            ResizeTensor(size),
        ]),
        target_transform=label_transform,
        num_copies=1,
    )

    if not linearCls:
        train_transform = transforms.Compose([
            transforms.RandomCrop(size, padding=size // 12),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
        ])
        train_dataset = SubsetTransformDataset(
            cached_dataset,
            train_sample_index,
            transform=TwoCropsTransform(train_transform),
        )
        knn_train_dataset = SubsetTransformDataset(cached_dataset, train_sample_index)
        test_dataset = SubsetTransformDataset(cached_dataset, test_sample_index)
        for dataset in (train_dataset, test_dataset, knn_train_dataset):
            dataset.classes = cls_count
        return train_dataset, knn_train_dataset, test_dataset

    # getattr works for any args object, not only those supporting ``in``.
    if getattr(args, "is_knn_dataset", False):
        # No augmentation. torchvision.transforms.v2.Compose rejects an empty
        # list, so express "identity" as transform=None, which
        # SubsetTransformDataset skips.
        train_transform = None
        print(args.is_knn_dataset)
    else:
        train_transform = transforms.Compose([
            transforms.RandomCrop(size, padding=size // 12),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15)
        ])

    train_dataset = SubsetTransformDataset(
        cached_dataset,
        train_sample_index,
        transform=train_transform,
    )
    test_dataset = SubsetTransformDataset(cached_dataset, test_sample_index)
    train_dataset.classes = cls_count
    test_dataset.classes = cls_count
    return train_dataset, test_dataset

def get_CIFAR10DVS_datasets(args, linearCls=False):
    """
    Get CIFAR10-DVS dataset
    https://www.frontiersin.org/articles/10.3389/fnins.2017.00309/full

    Events are binned into ``args.step`` frames with tonic's ``ToFrame``,
    converted to float tensors and resized to ``48 x 48``. Train and test
    samples are cached on disk in separate directories under
    ``args.dataset_path`` (``cache_train_<step>_default`` and
    ``cache_test_<step>_default``).

    The dataset has no default split, so each class contributes its first 90%
    of samples to training and the rest to testing. This assumes samples are
    ordered by class with ``len(dataset) // 10`` samples per class (TODO
    confirm against tonic's ordering).

    Args:
        args: namespace with ``dataset_path`` and ``step``. ``is_knn_dataset``
            is optional and only consulted when ``linearCls`` is True.
        linearCls: False -> pre-training datasets whose training samples are
            two augmented views; True -> linear-evaluation datasets.

    Returns:
        ``(train_dataset, knn_train_dataset, test_dataset)`` when ``linearCls``
        is False, otherwise ``(train_dataset, test_dataset)``. Each dataset gets
        a ``classes`` list with the ten CIFAR-10 class names.
    """
    class_names = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')
    sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
    portion = 0.9
    size = 48

    # Only the deterministic conversion is cached; random augmentations are applied
    # per access by SubsetTransformDataset below.
    toframe_transform = transforms.Compose([
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=args.step)
    ])
    base_dataset = tonic.datasets.CIFAR10DVS(args.dataset_path, transform=toframe_transform)
    cache_path_train = os.path.join(args.dataset_path, 'cache_train_{}_default'.format(args.step))
    cache_path_test = os.path.join(args.dataset_path, 'cache_test_{}_default'.format(args.step))
    cached_dataset_train = DiskCachedDataset(
        base_dataset,
        cache_path=cache_path_train,
        transform=transforms.Compose([
            to_tensor,
            ResizeTensor(size),
        ]),
        num_copies=1,
    )
    cached_dataset_test = DiskCachedDataset(
        base_dataset,
        cache_path=cache_path_test,
        transform=transforms.Compose([
            to_tensor,
            ResizeTensor(size),
        ]),
        num_copies=1,
    )

    train_sample_index = []
    test_sample_index = []
    num_per_cls = len(base_dataset) // 10
    for i in range(10):
        start = i * num_per_cls
        boundary = round(start + num_per_cls * portion)
        train_sample_index.extend(range(start, boundary))
        test_sample_index.extend(range(boundary, start + num_per_cls))

    if not linearCls:
        # Pre-training views: random flip along dim 3 plus a random spatial roll.
        train_transform = transforms.Compose([
            RandomFlipSequence(p=0.5, dim=3),
            RandomRollSequence(max_shift=5),
        ])
        train_dataset = SubsetTransformDataset(
            cached_dataset_train,
            train_sample_index,
            transform=TwoCropsTransform(train_transform),
        )
        knn_train_dataset = SubsetTransformDataset(cached_dataset_train, train_sample_index)
        test_dataset = SubsetTransformDataset(cached_dataset_test, test_sample_index)
        for dataset in (train_dataset, test_dataset, knn_train_dataset):
            dataset.classes = list(class_names)
        return train_dataset, knn_train_dataset, test_dataset

    # getattr works for any args object, not only those supporting ``in``.
    if getattr(args, "is_knn_dataset", False):
        # No augmentation. torchvision.transforms.v2.Compose rejects an empty
        # list, so express "identity" as transform=None, which
        # SubsetTransformDataset skips.
        train_transform = None
        print(args.is_knn_dataset)
    else:
        train_transform = transforms.Compose([
            transforms.RandomCrop(size, padding=size // 12),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15)
        ])

    train_dataset = SubsetTransformDataset(
        cached_dataset_train,
        train_sample_index,
        transform=train_transform,
    )
    test_dataset = SubsetTransformDataset(cached_dataset_test, test_sample_index)
    train_dataset.classes = list(class_names)
    test_dataset.classes = list(class_names)
    return train_dataset, test_dataset

def get_ucf101_datasets(args,linearCls=False):
    """Build UCF-101 (split 1) frame datasets for pre-training or linear evaluation.

    Args:
        args: namespace with ``size``, ``dataset_path``, ``frames_per_clip``,
            ``frame_rate``, ``channel_first`` and ``num_clips``. When
            ``linearCls`` is False it must also provide ``multi_train_sample``
            and ``train_num_clips``.
        linearCls: False -> pre-training datasets whose training samples are
            two augmented views; True -> linear-evaluation datasets.

    Returns:
        ``(train_dataset, knn_train_dataset, test_dataset)`` when ``linearCls``
        is False, otherwise ``(train_dataset, test_dataset)``.
    """
    size = args.size
    datadir = args.dataset_path
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    train_list = "ucfTrainTestlist/trainlist01.txt"
    test_list = "ucfTrainTestlist/testlist01.txt"

    def make_split(split_txt, mode, transform, **extra):
        # All splits share the frame root, class index file and clip geometry.
        return UCF101FramesDatasetV2(
            frames_root=os.path.join(datadir, "UCF-101-frames"),
            split_txt=os.path.join(datadir, split_txt),
            class_ind_txt=os.path.join(datadir, "ucfTrainTestlist/classInd.txt"),
            num_frames=args.frames_per_clip,
            interval=args.frame_rate,
            mode=mode,
            channel_first=args.channel_first,
            transform=transform,
            **extra,
        )

    # Deterministic evaluation pipeline: upscale by 1.2x, then centre-crop to ``size``.
    eval_pipeline = [
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Resize(int(size * 1.2)),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=mean, std=std),
    ]

    if linearCls:
        fit_pipeline = [
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Resize(int(size * 1.2)),
            transforms.RandomCrop(size),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(mean=mean, std=std),
        ]
        train_dataset = make_split(train_list, "train", transforms.Compose(fit_pipeline))
        test_dataset = make_split(
            test_list, "test", transforms.Compose(eval_pipeline), num_clips=args.num_clips
        )
        return train_dataset, test_dataset

    pretrain_pipeline = [
        transforms.ToDtype(torch.float32, scale=True),
        transforms.RandomResizedCrop(size, scale=(0.2, 0.766), ratio=(0.75, 1.3333)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.Normalize(mean=mean, std=std),
    ]
    train_dataset = make_split(
        train_list,
        "train",
        TwoCropsTransform(transforms.Compose(pretrain_pipeline)),
        multi_train_sample=args.multi_train_sample,
        num_clips=args.train_num_clips,
    )
    # The kNN reference set reads the training split but uses the evaluation pipeline.
    knn_train_dataset = make_split(train_list, "train", transforms.Compose(eval_pipeline))
    test_dataset = make_split(
        test_list, "test", transforms.Compose(eval_pipeline), num_clips=args.num_clips
    )
    return train_dataset, knn_train_dataset, test_dataset
    
def get_hmdb51_datasets(args,linearCls=False):
    """Build HMDB51 (split 1) frame datasets.

    Frames are read from ``<dataset_path>/HMDB51-frames`` and split
    definitions from ``<dataset_path>/testTrainMulti_7030_splits/``.

    Args:
        args: namespace providing ``size``, ``dataset_path``,
            ``frames_per_clip``, ``frame_rate`` and ``num_clips``.
        linearCls: selects which datasets are produced.
            * False: returns ``(train, knn_train, test)``. ``train`` applies
              strong augmentation (random resized crop, flip, colour jitter,
              random grayscale) wrapped in ``TwoCropsTransform``;
              ``knn_train`` reads the same training split in ``'train'``
              mode but with the evaluation pipeline.
            * True: returns ``(train, test)`` where ``train`` uses
              resize + random crop + horizontal flip.

    In both cases ``test`` uses resize + centre crop and is constructed
    with ``num_clips=args.num_clips``.
    """
    crop = args.size
    # The short side is resized to 1.2x the crop size before cropping
    # (e.g. crop 224 -> resize 268).
    resize_to = int(crop * 1.2)

    def _normalize():
        # Standard ImageNet per-channel mean/std.
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])

    # Deterministic evaluation pipeline, shared by every non-training split.
    eval_pipeline = [
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop),
        _normalize(),
    ]

    if linearCls:
        train_pipeline = [
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Resize(resize_to),
            transforms.RandomCrop(crop),
            transforms.RandomHorizontalFlip(),
            _normalize(),
        ]
    else:
        train_pipeline = [
            transforms.ToDtype(torch.float32, scale=True),
            transforms.RandomResizedCrop(crop, scale=(0.2, 0.766), ratio=(0.75, 1.3333)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            _normalize(),
        ]

    root = args.dataset_path
    # Keyword arguments identical for every HMDB51 dataset built here.
    # NOTE(review): unlike the UCF101/Kinetics factories, ``channel_first``
    # is not forwarded here -- confirm HMDB51FramesDatasetV2 is meant to
    # ignore ``args.channel_first``.
    split_opts = dict(
        frames_root=os.path.join(root, "HMDB51-frames"),
        split_dir=os.path.join(root, "testTrainMulti_7030_splits/"),
        split_num=1,
        num_frames=args.frames_per_clip,
        interval=args.frame_rate,
    )

    train_transform = transforms.Compose(train_pipeline)
    if not linearCls:
        # Produce two independently augmented views of every clip.
        train_transform = TwoCropsTransform(train_transform)

    train_set = HMDB51FramesDatasetV2(
        **split_opts,
        mode='train',
        transform=train_transform,
    )

    if linearCls:
        test_set = HMDB51FramesDatasetV2(
            **split_opts,
            mode='test',
            num_clips=args.num_clips,
            transform=transforms.Compose(eval_pipeline),
        )
        return train_set, test_set

    knn_set = HMDB51FramesDatasetV2(
        **split_opts,
        mode='train',
        transform=transforms.Compose(eval_pipeline),
    )
    test_set = HMDB51FramesDatasetV2(
        **split_opts,
        mode='test',
        num_clips=args.num_clips,
        transform=transforms.Compose(eval_pipeline),
    )
    return train_set, knn_set, test_set

def get_kinetics400_datasets(args,linearCls=False):
    """Build Kinetics-400 frame datasets backed by LMDB files.

    Training data is read from ``<dataset_path>/kinetics_train.lmdb`` and
    validation data from ``<dataset_path>/kinetics_val.lmdb``.

    Args:
        args: namespace providing ``size``, ``dataset_path``,
            ``frames_per_clip``, ``frame_rate``, ``channel_first`` and
            ``num_clips``.
        linearCls: selects which datasets are produced.
            * False: returns ``(train, knn_train, val)``. ``train`` applies
              strong augmentation (random resized crop, flip, colour jitter,
              random grayscale) wrapped in ``TwoCropsTransform``;
              ``knn_train`` reads the training LMDB in ``"knn_train"`` mode
              with the evaluation pipeline.
            * True: returns ``(train, val)`` where ``train`` uses
              resize + random crop + horizontal flip.

    In both cases ``val`` uses resize + centre crop and is constructed
    with ``num_clips=args.num_clips``.
    """
    crop = args.size
    # The short side is resized to 1.2x the crop size before cropping.
    resize_to = int(crop * 1.2)

    def _normalize():
        # Per-channel mean/std applied to every Kinetics split built here.
        return transforms.Normalize(
            mean=[0.43216, 0.394666, 0.37645],
            std=[0.22803, 0.22145, 0.216989])

    # Deterministic evaluation pipeline, shared by every non-training split.
    eval_pipeline = [
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop),
        _normalize(),
    ]

    if linearCls:
        train_pipeline = [
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Resize(resize_to),
            transforms.RandomCrop(crop),
            transforms.RandomHorizontalFlip(),
            _normalize(),
        ]
    else:
        train_pipeline = [
            transforms.ToDtype(torch.float32, scale=True),
            transforms.RandomResizedCrop(crop, scale=(0.2, 0.766), ratio=(0.75, 1.3333)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            _normalize(),
        ]

    root = args.dataset_path
    train_lmdb = os.path.join(root, "kinetics_train.lmdb")
    val_lmdb = os.path.join(root, "kinetics_val.lmdb")
    # Clip-sampling options identical for every split.
    clip_opts = dict(
        num_frames=args.frames_per_clip,
        interval=args.frame_rate,
        channel_first=args.channel_first,
    )

    train_transform = transforms.Compose(train_pipeline)
    if not linearCls:
        # Produce two independently augmented views of every clip.
        train_transform = TwoCropsTransform(train_transform)

    train_set = Kinetics400FrameLMDBDataset(
        lmdb_path=train_lmdb,
        mode="train",
        transform=train_transform,
        **clip_opts,
    )

    if linearCls:
        val_set = Kinetics400FrameLMDBDataset(
            lmdb_path=val_lmdb,
            mode="val",
            num_clips=args.num_clips,
            transform=transforms.Compose(eval_pipeline),
            **clip_opts,
        )
        return train_set, val_set

    knn_set = Kinetics400FrameLMDBDataset(
        lmdb_path=train_lmdb,
        mode="knn_train",
        transform=transforms.Compose(eval_pipeline),
        **clip_opts,
    )
    val_set = Kinetics400FrameLMDBDataset(
        lmdb_path=val_lmdb,
        mode="val",
        num_clips=args.num_clips,
        transform=transforms.Compose(eval_pipeline),
        **clip_opts,
    )
    return train_set, knn_set, val_set

def get_minik400_datasets(args,linearCls=False):
    """Alias of ``get_kinetics400_datasets``.

    Accepts the same ``args`` and ``linearCls`` and returns exactly what
    ``get_kinetics400_datasets`` returns: ``(train, knn_train, val)`` when
    ``linearCls`` is False, otherwise ``(train, val)``.
    """
    return get_kinetics400_datasets(args, linearCls=linearCls)